package pbe

import (
	"bytes"
	"crypto/cipher"
	"crypto/des"
	"crypto/md5"
	"encoding/base64"
	"errors"
	"strings"
)

// getDerivedKey stretches password and salt into DES key material.
//
// It hashes the bytes of password followed by salt with MD5 and then re-hashes
// the digest iterations-1 more times, so any iterations value below 2 results
// in a single MD5 pass. The first half of the final 16-byte digest is returned
// as the cipher key and the second half as the CBC initialization vector.
func getDerivedKey(password string, salt []byte, iterations int) ([]byte, []byte) {
	seed := make([]byte, 0, len(password)+len(salt))
	seed = append(seed, password...)
	seed = append(seed, salt...)

	digest := md5.Sum(seed)
	// The initial hash above counts as the first round.
	for remaining := iterations - 1; remaining > 0; remaining-- {
		digest = md5.Sum(digest[:])
	}

	const half = md5.Size / 2
	return digest[:half], digest[half:]
}

// Encrypt encrypts plainText with DES in CBC mode and returns the ciphertext
// encoded with standard base64.
//
// The key and IV come from getDerivedKey(password, salt, iterations). Before
// encryption the plaintext bytes are extended to a multiple of the DES block
// size with N bytes of value N, where N is between 1 and 8; a full block of
// padding is added when the input is already block-aligned.
func Encrypt(password string, salt []byte, iterations int, plainText string) (string, error) {
	key, iv := getDerivedKey(password, salt, iterations)
	block, err := des.NewCipher(key)
	if err != nil {
		return "", err
	}

	padLen := des.BlockSize - len(plainText)%des.BlockSize
	buf := make([]byte, len(plainText)+padLen)
	// copy reports how many plaintext bytes were written; everything after
	// that offset is padding.
	for i := copy(buf, plainText); i < len(buf); i++ {
		buf[i] = byte(padLen)
	}

	// Encrypting in place is allowed because dst and src overlap entirely.
	cipher.NewCBCEncrypter(block, iv).CryptBlocks(buf, buf)
	return base64.StdEncoding.EncodeToString(buf), nil
}

// EncryptBytes encrypts plainData with DES in CBC mode and returns the raw
// ciphertext.
//
// The key and IV come from getDerivedKey(password, salt, iterations). The
// input is padded to a multiple of the DES block size with N bytes of value N
// (1 <= N <= 8); block-aligned input receives a full extra block of padding.
//
// plainData is never modified: padding is written into a freshly allocated
// buffer rather than appended to the caller's slice, which could otherwise
// overwrite bytes in the caller's backing array when the slice has spare
// capacity.
func EncryptBytes(password string, salt []byte, iterations int, plainData []byte) (dst []byte, err error) {
	dk, iv := getDerivedKey(password, salt, iterations)
	block, err := des.NewCipher(dk)
	if err != nil {
		return nil, err
	}

	padNum := des.BlockSize - len(plainData)%des.BlockSize
	dst = make([]byte, len(plainData)+padNum)
	copy(dst, plainData)
	for i := len(plainData); i < len(dst); i++ {
		dst[i] = byte(padNum)
	}

	// Encrypting in place is allowed because dst and src overlap entirely.
	cipher.NewCBCEncrypter(block, iv).CryptBlocks(dst, dst)
	return dst, nil
}

// Decrypt reverses Encrypt: it base64-decodes cipherText, decrypts it with
// DES in CBC mode using the key and IV from
// getDerivedKey(password, salt, iterations), and removes the block padding.
//
// An error is returned when cipherText is not valid standard base64, when the
// decoded data is empty or not a whole number of DES blocks (which would
// otherwise make CryptBlocks panic), or when the decrypted data does not end
// in valid padding. The latter is the typical symptom of a wrong password,
// salt, or iteration count.
//
// Exactly the N padding bytes announced by the final byte are removed, so
// plaintext that itself ends in bytes 0x01..0x08 is returned intact.
func Decrypt(password string, salt []byte, iterations int, cipherText string) (string, error) {
	msgBytes, err := base64.StdEncoding.DecodeString(cipherText)
	if err != nil {
		return "", err
	}
	if len(msgBytes) == 0 || len(msgBytes)%des.BlockSize != 0 {
		return "", errors.New("pbe: ciphertext length is not a positive multiple of the DES block size")
	}

	dk, iv := getDerivedKey(password, salt, iterations)
	block, err := des.NewCipher(dk)
	if err != nil {
		return "", err
	}

	decrypted := make([]byte, len(msgBytes))
	cipher.NewCBCDecrypter(block, iv).CryptBlocks(decrypted, msgBytes)

	// The last byte states how many padding bytes were appended; every one of
	// those bytes must carry the same value.
	text := string(decrypted)
	last := text[len(text)-1:]
	padLen := int(last[0])
	if padLen < 1 || padLen > des.BlockSize {
		return "", errors.New("pbe: invalid padding")
	}
	padding := strings.Repeat(last, padLen)
	if !strings.HasSuffix(text, padding) {
		return "", errors.New("pbe: invalid padding")
	}
	return strings.TrimSuffix(text, padding), nil
}

// DecryptBytes reverses EncryptBytes: it decrypts cipherData with DES in CBC
// mode using the key and IV from getDerivedKey(password, salt, iterations)
// and removes the block padding. cipherData itself is not modified.
//
// An error is returned when cipherData is empty or not a whole number of DES
// blocks (which would otherwise make CryptBlocks panic), or when the
// decrypted data does not end in valid padding. The latter is the typical
// symptom of a wrong password, salt, or iteration count.
//
// Exactly the N padding bytes announced by the final byte are removed, so
// plaintext that itself ends in bytes 0x01..0x08 is returned intact.
func DecryptBytes(password string, salt []byte, iterations int, cipherData []byte) (dst []byte, err error) {
	if len(cipherData) == 0 || len(cipherData)%des.BlockSize != 0 {
		return nil, errors.New("pbe: ciphertext length is not a positive multiple of the DES block size")
	}

	dk, iv := getDerivedKey(password, salt, iterations)
	block, err := des.NewCipher(dk)
	if err != nil {
		return nil, err
	}

	dst = make([]byte, len(cipherData))
	cipher.NewCBCDecrypter(block, iv).CryptBlocks(dst, cipherData)

	// The last byte states how many padding bytes were appended; every one of
	// those bytes must carry the same value.
	padLen := int(dst[len(dst)-1])
	if padLen < 1 || padLen > des.BlockSize {
		return nil, errors.New("pbe: invalid padding")
	}
	padding := bytes.Repeat(dst[len(dst)-1:], padLen)
	if !bytes.HasSuffix(dst, padding) {
		return nil, errors.New("pbe: invalid padding")
	}
	return bytes.TrimSuffix(dst, padding), nil
}
